backend_api.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. from __future__ import annotations
  2. import json
  3. import flask
  4. import os
  5. import logging
  6. import asyncio
  7. import shutil
  8. import random
  9. import datetime
  10. import tempfile
  11. from flask import Flask, Response, redirect, request, jsonify, render_template, send_from_directory
  12. from werkzeug.exceptions import NotFound
  13. from typing import Generator
  14. from pathlib import Path
  15. from urllib.parse import quote_plus
  16. from hashlib import sha256
  17. from ...client.service import convert_to_provider
  18. from ...providers.asyncio import to_sync_generator
  19. from ...client.helper import filter_markdown
  20. from ...tools.files import supports_filename, get_streaming, get_bucket_dir, get_buckets
  21. from ...tools.run_tools import iter_run_tools
  22. from ...errors import ProviderNotFoundError
  23. from ...image import is_allowed_extension
  24. from ...cookies import get_cookies_dir
  25. from ...image.copy_images import secure_filename, get_source_url, images_dir
  26. from ... import ChatCompletion
  27. from ... import models
  28. from .api import Api
  29. logger = logging.getLogger(__name__)
  def safe_iter_generator(generator: Generator) -> Generator:
      """
      Pull the first item eagerly, then return a generator over all items.

      Fetching the first item before returning means that an exception raised
      while producing it surfaces to the caller of this function instead of
      during later iteration of the returned generator.

      Args:
          generator: The source to iterate. Any iterable is accepted.

      Returns:
          Generator: Yields the already-fetched first item followed by the
          remaining items. Yields nothing when the source is empty.
      """
      missing = object()
      iterator = iter(generator)
      # Use a default so an empty source does not leak StopIteration to the caller.
      start = next(iterator, missing)

      def iter_generator():
          if start is not missing:
              yield start
              yield from iterator

      return iter_generator()
  36. class Backend_Api(Api):
  37. """
  38. Handles various endpoints in a Flask application for backend operations.
  39. This class provides methods to interact with models, providers, and to handle
  40. various functionalities like conversations, error handling, and version management.
  41. Attributes:
  42. app (Flask): A Flask application instance.
  43. routes (dict): A dictionary mapping API endpoints to their respective handlers.
  44. """
  def __init__(self, app: Flask) -> None:
      """
      Initialize the backend API with the given Flask application.

      Registers the route handlers defined below directly on ``app`` and
      stores the remaining endpoint-to-handler mappings in ``self.routes``.

      Args:
          app (Flask): Flask application instance to attach routes to.
              Its ``demo`` attribute selects the demo variants of the
              home page, the model list and the conversation provider.
      """
      self.app: Flask = app
      # share_id -> "updated" value of the chat last read or written under that id.
      self.chat_cache = {}

      if app.demo:
          @app.route('/', methods=['GET'])
          def home():
              client_id = os.environ.get("OAUTH_CLIENT_ID", "ed074164-4f8d-4fb2-8bec-44952707965e")
              backend_url = os.environ.get("G4F_BACKEND_URL", "")
              return render_template('demo.html', backend_url=backend_url, client_id=client_id)
      else:
          @app.route('/', methods=['GET'])
          def home():
              return render_template('home.html')

      @app.route('/qrcode', methods=['GET'])
      @app.route('/qrcode/<conversation_id>', methods=['GET'])
      def qrcode(conversation_id: str = ""):
          share_url = os.environ.get("G4F_SHARE_URL", "")
          return render_template('qrcode.html', conversation_id=conversation_id, share_url=share_url)

      @app.route('/backend-api/v2/models', methods=['GET'])
      def jsonify_models(**kwargs):
          response = get_demo_models() if app.demo else self.get_models(**kwargs)
          # Lists are serialized here; anything else is returned unchanged.
          if isinstance(response, list):
              return jsonify(response)
          return response

      @app.route('/backend-api/v2/models/<provider>', methods=['GET'])
      def jsonify_provider_models(**kwargs):
          response = self.get_provider_models(**kwargs)
          if isinstance(response, list):
              return jsonify(response)
          return response

      @app.route('/backend-api/v2/providers', methods=['GET'])
      def jsonify_providers(**kwargs):
          response = self.get_providers(**kwargs)
          if isinstance(response, list):
              return jsonify(response)
          return response

      def get_demo_models():
          """Describe every entry of ``models.demo_models`` as a plain dict."""
          return [{
              "name": model.name,
              "image": isinstance(model, models.ImageModel),
              "vision": isinstance(model, models.VisionModel),
              "providers": [
                  getattr(provider, "parent", provider.__name__)
                  for provider in providers
              ],
              "demo": True
          }
          for model, providers in models.demo_models.values()]

      def handle_conversation():
          """
          Handles conversation requests and streams responses back.

          Returns:
              Response: A Flask response object for streaming.
          """
          # The request body is either a "json" form field or a JSON payload.
          if "json" in request.form:
              json_data = json.loads(request.form['json'])
          else:
              json_data = request.json
          if "files" in request.files:
              media = []
              for file in request.files.getlist('files'):
                  if file.filename != '' and is_allowed_extension(file.filename):
                      # The temporary file is left open on purpose: it is handed on
                      # inside json_data['media'] together with the original name.
                      newfile = tempfile.TemporaryFile()
                      shutil.copyfileobj(file.stream, newfile)
                      media.append((newfile, file.filename))
              json_data['media'] = media

          # In demo mode a provider is picked for requests that do not name one.
          if app.demo and not json_data.get("provider"):
              model = json_data.get("model")
              if model != "default" and model in models.demo_models:
                  json_data["provider"] = random.choice(models.demo_models[model][1])
              else:
                  json_data["provider"] = models.HuggingFace
          kwargs = self._prepare_conversation_kwargs(json_data)
          return self.app.response_class(
              self._create_response_stream(
                  kwargs,
                  json_data.get("conversation_id"),
                  json_data.get("provider"),
                  json_data.get("download_media", True),
              ),
              mimetype='text/event-stream'
          )

      @app.route('/backend-api/v2/conversation', methods=['POST'])
      def _handle_conversation():
          return handle_conversation()

      @app.route('/backend-api/v2/usage', methods=['POST'])
      def add_usage():
          # Appends the posted JSON as one line to a per-day .jsonl file.
          cache_dir = Path(get_cookies_dir()) / ".usage"
          cache_file = cache_dir / f"{datetime.date.today()}.jsonl"
          cache_dir.mkdir(parents=True, exist_ok=True)
          # Append mode creates the file when it does not exist yet.
          with cache_file.open("a") as f:
              f.write(f"{json.dumps(request.json)}\n")
          return {}

      @app.route('/backend-api/v2/log', methods=['POST'])
      def add_log():
          # Same as add_usage, with the request's "origin" header added to the record.
          cache_dir = Path(get_cookies_dir()) / ".logging"
          cache_file = cache_dir / f"{datetime.date.today()}.jsonl"
          cache_dir.mkdir(parents=True, exist_ok=True)
          data = {"origin": request.headers.get("origin"), **request.json}
          with cache_file.open("a") as f:
              f.write(f"{json.dumps(data)}\n")
          return {}

      @app.route('/backend-api/v2/memory/<user_id>', methods=['POST'])
      def add_memory(user_id: str):
          api_key = request.headers.get("x_api_key")
          json_data = request.json
          # Imported lazily so mem0 is only required when this endpoint is used.
          from mem0 import MemoryClient
          client = MemoryClient(api_key=api_key)
          client.add(
              [{"role": item["role"], "content": item["content"]} for item in json_data.get("items")],
              user_id=user_id,
              metadata={"conversation_id": json_data.get("id")}
          )
          return {"count": len(json_data.get("items"))}

      @app.route('/backend-api/v2/memory/<user_id>', methods=['GET'])
      def read_memory(user_id: str):
          api_key = request.headers.get("x_api_key")
          from mem0 import MemoryClient
          client = MemoryClient(api_key=api_key)
          # A "search" query argument switches from listing to searching.
          if request.args.get("search"):
              return client.search(
                  request.args.get("search"),
                  user_id=user_id,
                  filters=json.loads(request.args.get("filters", "null")),
                  metadata=json.loads(request.args.get("metadata", "null"))
              )
          return client.get_all(
              user_id=user_id,
              page=request.args.get("page", 1),
              page_size=request.args.get("page_size", 100),
              filters=json.loads(request.args.get("filters", "null")),
          )

      self.routes = {
          '/backend-api/v2/version': {
              'function': self.get_version,
              'methods': ['GET']
          },
          '/backend-api/v2/synthesize/<provider>': {
              'function': self.handle_synthesize,
              'methods': ['GET']
          },
          '/images/<path:name>': {
              'function': self.serve_images,
              'methods': ['GET']
          },
          '/media/<path:name>': {
              'function': self.serve_images,
              'methods': ['GET']
          }
      }

      @app.route('/backend-api/v2/create', methods=['GET', 'POST'])
      def create():
          try:
              tool_calls = [{
                  "function": {
                      "name": "bucket_tool"
                  },
                  "type": "function"
              }]
              web_search = request.args.get("web_search")
              if web_search:
                  # Any value other than "true" is used as the search query itself.
                  tool_calls.append({
                      "function": {
                          "name": "search_tool",
                          "arguments": {"query": web_search, "instructions": "", "max_words": 1000} if web_search != "true" else {}
                      },
                      "type": "function"
                  })
              do_filter_markdown = request.args.get("filter_markdown")
              cache_id = request.args.get('cache')
              parameters = {
                  "model": request.args.get("model"),
                  "messages": [{"role": "user", "content": request.args.get("prompt")}],
                  "provider": request.args.get("provider", None),
                  "stream": not do_filter_markdown and not cache_id,
                  "ignore_stream": not request.args.get("stream"),
                  "tool_calls": tool_calls,
              }
              if cache_id:
                  # The cache key covers the caller's id plus every request parameter.
                  cache_id = sha256(cache_id.encode() + json.dumps(parameters, sort_keys=True).encode()).hexdigest()
                  cache_dir = Path(get_cookies_dir()) / ".scrape_cache" / "create"
                  cache_file = cache_dir / f"{quote_plus(request.args.get('prompt').strip()[:20])}.{cache_id}.txt"
                  if cache_file.exists():
                      with cache_file.open("r") as f:
                          response = f.read()
                  else:
                      response = iter_run_tools(ChatCompletion.create, **parameters)
                      cache_dir.mkdir(parents=True, exist_ok=True)
                      # NOTE(review): response is iterated here and again below; this
                      # presumably relies on a re-iterable result when "stream" is
                      # False - confirm against iter_run_tools.
                      with cache_file.open("w") as f:
                          for chunk in response:
                              f.write(str(chunk))
              else:
                  response = iter_run_tools(ChatCompletion.create, **parameters)

              if do_filter_markdown:
                  return Response(filter_markdown("".join([str(chunk) for chunk in response]), do_filter_markdown), mimetype='text/plain')

              def cast_str():
                  # Exception objects in the stream are dropped from the output.
                  for chunk in response:
                      if not isinstance(chunk, Exception):
                          yield str(chunk)
              return Response(cast_str(), mimetype='text/plain')
          except Exception as e:
              logger.exception(e)
              return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 500

      @app.route('/backend-api/v2/files/<bucket_id>', methods=['GET', 'DELETE'])
      def manage_files(bucket_id: str):
          bucket_id = secure_filename(bucket_id)
          bucket_dir = get_bucket_dir(bucket_id)

          if not os.path.isdir(bucket_dir):
              return jsonify({"error": {"message": "Bucket directory not found"}}), 404

          if request.method == 'DELETE':
              try:
                  shutil.rmtree(bucket_dir)
                  return jsonify({"message": "Bucket deleted successfully"}), 200
              except OSError as e:
                  return jsonify({"error": {"message": f"Error deleting bucket: {str(e)}"}}), 500
              except Exception as e:
                  return jsonify({"error": {"message": str(e)}}), 500

          # NOTE(review): query arguments arrive as strings, so any non-empty value
          # (including "false") is truthy here - confirm how get_streaming reads them.
          delete_files = request.args.get('delete_files', True)
          refine_chunks_with_spacy = request.args.get('refine_chunks_with_spacy', False)
          event_stream = 'text/event-stream' in request.headers.get('Accept', '')
          mimetype = "text/event-stream" if event_stream else "text/plain"
          return Response(get_streaming(bucket_dir, delete_files, refine_chunks_with_spacy, event_stream), mimetype=mimetype)

      @self.app.route('/backend-api/v2/files/<bucket_id>', methods=['POST'])
      def upload_files(bucket_id: str):
          bucket_id = secure_filename(bucket_id)
          bucket_dir = get_bucket_dir(bucket_id)
          media_dir = os.path.join(bucket_dir, "media")
          os.makedirs(bucket_dir, exist_ok=True)
          filenames = []
          media = []
          for file in request.files.getlist('files'):
              try:
                  filename = secure_filename(file.filename)
                  if is_allowed_extension(filename):
                      # Media files go into the bucket's "media" sub-directory.
                      os.makedirs(media_dir, exist_ok=True)
                      newfile = os.path.join(media_dir, filename)
                      media.append(filename)
                  elif supports_filename(filename):
                      newfile = os.path.join(bucket_dir, filename)
                      filenames.append(filename)
                  else:
                      # Unsupported uploads are skipped silently.
                      continue
                  with open(newfile, 'wb') as f:
                      shutil.copyfileobj(file.stream, f)
              finally:
                  file.stream.close()
          # files.txt lists the non-media files stored in the bucket, one per line.
          with open(os.path.join(bucket_dir, "files.txt"), 'w') as f:
              f.writelines(f"{filename}\n" for filename in filenames)
          return {"bucket_id": bucket_id, "files": filenames, "media": media}

      @app.route('/files/<bucket_id>/media/<filename>', methods=['GET'])
      def get_media(bucket_id, filename, dirname: str = None):
          # NOTE(review): bucket_id is not passed through secure_filename here,
          # unlike the other bucket routes - confirm get_bucket_dir handles it.
          media_dir = get_bucket_dir(dirname, bucket_id, "media")
          try:
              return send_from_directory(os.path.abspath(media_dir), filename)
          except NotFound:
              # Fall back to a redirect when the query string yields a source URL.
              source_url = get_source_url(request.query_string.decode())
              if source_url is not None:
                  return redirect(source_url)
              raise

      @app.route('/search/<search>', methods=['GET'])
      def find_media(search: str):
          search = [secure_filename(chunk.lower()) for chunk in search.split("+")]
          if not os.access(images_dir, os.R_OK):
              return jsonify({"error": {"message": "Not found"}}), 404
          # file name -> number of tag hits (at most one for the mime type, one per tag for the name).
          match_files = {}
          for root, _, files in os.walk(images_dir):
              for file in files:
                  mime_type = is_allowed_extension(file)
                  if mime_type is not None:
                      mime_type = secure_filename(mime_type)
                      for tag in search:
                          if tag in mime_type:
                              match_files[file] = match_files.get(file, 0) + 1
                              break
                  for tag in search:
                      if tag in file.lower():
                          match_files[file] = match_files.get(file, 0) + 1
          # Query arguments are strings; cast so the comparison with the int count works.
          min_count = int(request.args.get("min", len(search)))
          match_files = [file for file, count in match_files.items() if count >= min_count]
          skip = int(request.args.get("skip", 0))
          if skip >= len(match_files):
              return jsonify({"error": {"message": "Not found"}}), 404
          if request.args.get("random", False):
              return redirect(f"/media/{random.choice(match_files)}"), 302
          return redirect(f"/media/{match_files[skip]}", 302)

      @app.route('/backend-api/v2/upload_cookies', methods=['POST'])
      def upload_cookies():
          file = None
          if "file" in request.files:
              file = request.files['file']
              if file.filename == '':
                  return 'No selected file', 400
          # Guard against a request without a "file" part before reading its name.
          if file is not None and file.filename.endswith((".json", ".har")):
              filename = secure_filename(file.filename)
              file.save(os.path.join(get_cookies_dir(), filename))
              return "File saved", 200
          return 'Not supported file', 400

      @self.app.route('/backend-api/v2/chat/<share_id>', methods=['GET'])
      def get_chat(share_id: str) -> str:
          share_id = secure_filename(share_id)
          # NOTE(review): int() raises ValueError for a non-numeric "if-none-match"
          # header, and a missing header equals the default cache value 0, which
          # yields 304 for an uncached chat - confirm both are intended.
          if self.chat_cache.get(share_id, 0) == int(request.headers.get("if-none-match", 0)):
              return jsonify({"error": {"message": "Not modified"}}), 304
          file = get_bucket_dir(share_id, "chat.json")
          if not os.path.isfile(file):
              return jsonify({"error": {"message": "Not found"}}), 404
          with open(file, 'r') as f:
              chat_data = json.load(f)
              if chat_data.get("updated", 0) == int(request.headers.get("if-none-match", 0)):
                  return jsonify({"error": {"message": "Not modified"}}), 304
              self.chat_cache[share_id] = chat_data.get("updated", 0)
              return jsonify(chat_data), 200

      @self.app.route('/backend-api/v2/chat/<share_id>', methods=['POST'])
      def upload_chat(share_id: str) -> dict:
          # Sanitize first so the cache is read and written under the same key.
          share_id = secure_filename(share_id)
          chat_data = {**request.json}
          updated = chat_data.get("updated", 0)
          cache_value = self.chat_cache.get(share_id, 0)
          # An unchanged "updated" value is rejected as a duplicate upload.
          if updated == cache_value:
              return jsonify({"error": {"message": "invalid date"}}), 400
          bucket_dir = get_bucket_dir(share_id)
          os.makedirs(bucket_dir, exist_ok=True)
          with open(os.path.join(bucket_dir, "chat.json"), 'w') as f:
              json.dump(chat_data, f)
          self.chat_cache[share_id] = updated
          return {"share_id": share_id}
  def handle_synthesize(self, provider: str):
      """
      Run a provider's ``synthesize`` with the request's query arguments.

      Args:
          provider (str): Name resolved through ``convert_to_provider``.

      Returns:
          A Flask response carrying the synthesized data with a one-week
          ``Cache-Control`` header, or a ``(message, status)`` tuple when the
          provider is unknown (404) or has no ``synthesize`` attribute (500).
      """
      try:
          handler = convert_to_provider(provider)
      except ProviderNotFoundError:
          return "Provider not found", 404
      if not hasattr(handler, "synthesize"):
          return "Provider doesn't support synthesize", 500

      result = handler.synthesize({**request.args})
      if asyncio.iscoroutinefunction(handler.synthesize):
          # Coroutine functions are awaited to completion; the result is used as the body.
          body = asyncio.run(result)
      else:
          # Async iterators are bridged to a sync generator before streaming.
          stream = to_sync_generator(result) if hasattr(result, "__aiter__") else result
          body = safe_iter_generator(stream)

      mime = getattr(handler, "synthesize_content_type", "application/octet-stream")
      reply = flask.Response(body, content_type=mime)
      reply.headers['Cache-Control'] = "max-age=604800"
      return reply
  def get_provider_models(self, provider: str):
      """
      Look up a provider's models using credentials from the request headers.

      The ``x_api_key`` and ``x_api_base`` headers are forwarded to the
      parent implementation.

      Args:
          provider (str): Name of the provider to query.

      Returns:
          The parent's result, or ``("Provider not found", 404)`` when the
          parent returns ``None``.
      """
      headers = request.headers
      found = super().get_provider_models(
          provider,
          headers.get("x_api_key"),
          headers.get("x_api_base"),
      )
      return ("Provider not found", 404) if found is None else found
  def _format_json(self, response_type: str, content = None, **kwargs) -> str:
      """
      Serialize a response as one newline-terminated JSON line.

      Args:
          response_type (str): The type of the response.
          content: The content to be included in the response.
          **kwargs: Extra values forwarded to the parent implementation.

      Returns:
          str: JSON encoding of the parent's ``_format_json`` result plus "\\n".
      """
      payload = super()._format_json(response_type, content, **kwargs)
      return f"{json.dumps(payload)}\n"