backend_api.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. from __future__ import annotations
  2. import json
  3. import flask
  4. import os
  5. import logging
  6. import asyncio
  7. import shutil
  8. import random
  9. import datetime
  10. from flask import Flask, Response, request, jsonify, render_template
  11. from typing import Generator
  12. from pathlib import Path
  13. from urllib.parse import quote_plus
  14. from hashlib import sha256
  15. from werkzeug.utils import secure_filename
  16. try:
  17. from flask_limiter import Limiter
  18. from flask_limiter.util import get_remote_address
  19. has_flask_limiter = True
  20. except ImportError:
  21. has_flask_limiter = False
  22. from ...image import is_allowed_extension, to_image
  23. from ...client.service import convert_to_provider
  24. from ...providers.asyncio import to_sync_generator
  25. from ...client.helper import filter_markdown
  26. from ...tools.files import supports_filename, get_streaming, get_bucket_dir, get_buckets
  27. from ...tools.run_tools import iter_run_tools
  28. from ...errors import ProviderNotFoundError
  29. from ...cookies import get_cookies_dir
  30. from ... import ChatCompletion
  31. from ... import models
  32. from .api import Api
  33. logger = logging.getLogger(__name__)
  34. def safe_iter_generator(generator: Generator) -> Generator:
  35. start = next(generator)
  36. def iter_generator():
  37. yield start
  38. yield from generator
  39. return iter_generator()
  40. class Backend_Api(Api):
  41. """
  42. Handles various endpoints in a Flask application for backend operations.
  43. This class provides methods to interact with models, providers, and to handle
  44. various functionalities like conversations, error handling, and version management.
  45. Attributes:
  46. app (Flask): A Flask application instance.
  47. routes (dict): A dictionary mapping API endpoints to their respective handlers.
  48. """
  49. def __init__(self, app: Flask) -> None:
  50. """
  51. Initialize the backend API with the given Flask application.
  52. Args:
  53. app (Flask): Flask application instance to attach routes to.
  54. """
  55. self.app: Flask = app
  56. if has_flask_limiter and app.demo:
  57. limiter = Limiter(
  58. get_remote_address,
  59. app=app,
  60. default_limits=["200 per day", "50 per hour"],
  61. storage_uri="memory://",
  62. auto_check=False,
  63. strategy="moving-window",
  64. )
  65. if has_flask_limiter and app.demo:
  66. @app.route('/', methods=['GET'])
  67. @limiter.exempt
  68. def home():
  69. return render_template('demo.html')
  70. else:
  71. @app.route('/', methods=['GET'])
  72. def home():
  73. return render_template('home.html')
  74. @app.route('/backend-api/v2/models', methods=['GET'])
  75. def jsonify_models(**kwargs):
  76. response = get_demo_models() if app.demo else self.get_models(**kwargs)
  77. if isinstance(response, list):
  78. return jsonify(response)
  79. return response
  80. @app.route('/backend-api/v2/models/<provider>', methods=['GET'])
  81. def jsonify_provider_models(**kwargs):
  82. response = self.get_provider_models(**kwargs)
  83. if isinstance(response, list):
  84. return jsonify(response)
  85. return response
  86. @app.route('/backend-api/v2/providers', methods=['GET'])
  87. def jsonify_providers(**kwargs):
  88. response = self.get_providers(**kwargs)
  89. if isinstance(response, list):
  90. return jsonify(response)
  91. return response
  92. def get_demo_models():
  93. return [{
  94. "name": model.name,
  95. "image": isinstance(model, models.ImageModel),
  96. "vision": isinstance(model, models.VisionModel),
  97. "providers": [
  98. getattr(provider, "parent", provider.__name__)
  99. for provider in providers
  100. ],
  101. "demo": True
  102. }
  103. for model, providers in models.demo_models.values()]
  104. def handle_conversation():
  105. """
  106. Handles conversation requests and streams responses back.
  107. Returns:
  108. Response: A Flask response object for streaming.
  109. """
  110. kwargs = {}
  111. if "files[]" in request.files:
  112. images = []
  113. for file in request.files.getlist('files[]'):
  114. if file.filename != '' and is_allowed_extension(file.filename):
  115. images.append((to_image(file.stream, file.filename.endswith('.svg')), file.filename))
  116. kwargs['images'] = images
  117. if "json" in request.form:
  118. json_data = json.loads(request.form['json'])
  119. else:
  120. json_data = request.json
  121. if app.demo and json_data.get("provider") not in ["Custom", "Feature"]:
  122. model = json_data.get("model")
  123. if model != "default" and model in models.demo_models:
  124. json_data["provider"] = random.choice(models.demo_models[model][1])
  125. else:
  126. if not model or model == "default":
  127. json_data["model"] = models.demo_models["default"][0].name
  128. json_data["provider"] = random.choice(models.demo_models["default"][1])
  129. kwargs = self._prepare_conversation_kwargs(json_data, kwargs)
  130. return self.app.response_class(
  131. self._create_response_stream(
  132. kwargs,
  133. json_data.get("conversation_id"),
  134. json_data.get("provider"),
  135. json_data.get("download_images", True),
  136. ),
  137. mimetype='text/event-stream'
  138. )
  139. if has_flask_limiter and app.demo:
  140. @app.route('/backend-api/v2/conversation', methods=['POST'])
  141. @limiter.limit("2 per minute")
  142. def _handle_conversation():
  143. limiter.check()
  144. return handle_conversation()
  145. else:
  146. @app.route('/backend-api/v2/conversation', methods=['POST'])
  147. def _handle_conversation():
  148. return handle_conversation()
  149. @app.route('/backend-api/v2/usage', methods=['POST'])
  150. def add_usage():
  151. cache_dir = Path(get_cookies_dir()) / ".usage"
  152. cache_file = cache_dir / f"{datetime.date.today()}.jsonl"
  153. cache_dir.mkdir(parents=True, exist_ok=True)
  154. with cache_file.open("a" if cache_file.exists() else "w") as f:
  155. f.write(f"{json.dumps(request.json)}\n")
  156. return {}
  157. @app.route('/backend-api/v2/log', methods=['POST'])
  158. def add_log():
  159. cache_dir = Path(get_cookies_dir()) / ".logging"
  160. cache_file = cache_dir / f"{datetime.date.today()}.jsonl"
  161. cache_dir.mkdir(parents=True, exist_ok=True)
  162. data = {"origin": request.headers.get("origin"), **request.json}
  163. with cache_file.open("a" if cache_file.exists() else "w") as f:
  164. f.write(f"{json.dumps(data)}\n")
  165. return {}
  166. @app.route('/backend-api/v2/memory/<user_id>', methods=['POST'])
  167. def add_memory(user_id: str):
  168. api_key = request.headers.get("x_api_key")
  169. json_data = request.json
  170. from mem0 import MemoryClient
  171. client = MemoryClient(api_key=api_key)
  172. client.add(
  173. [{"role": item["role"], "content": item["content"]} for item in json_data.get("items")],
  174. user_id=user_id,
  175. metadata={"conversation_id": json_data.get("id")}
  176. )
  177. return {"count": len(json_data.get("items"))}
  178. @app.route('/backend-api/v2/memory/<user_id>', methods=['GET'])
  179. def read_memory(user_id: str):
  180. api_key = request.headers.get("x_api_key")
  181. from mem0 import MemoryClient
  182. client = MemoryClient(api_key=api_key)
  183. if request.args.get("search"):
  184. return client.search(
  185. request.args.get("search"),
  186. user_id=user_id,
  187. filters=json.loads(request.args.get("filters", "null")),
  188. metadata=json.loads(request.args.get("metadata", "null"))
  189. )
  190. return client.get_all(
  191. user_id=user_id,
  192. page=request.args.get("page", 1),
  193. page_size=request.args.get("page_size", 100),
  194. filters=json.loads(request.args.get("filters", "null")),
  195. )
  196. self.routes = {
  197. '/backend-api/v2/version': {
  198. 'function': self.get_version,
  199. 'methods': ['GET']
  200. },
  201. '/backend-api/v2/synthesize/<provider>': {
  202. 'function': self.handle_synthesize,
  203. 'methods': ['GET']
  204. },
  205. '/images/<path:name>': {
  206. 'function': self.serve_images,
  207. 'methods': ['GET']
  208. }
  209. }
  210. @app.route('/backend-api/v2/create', methods=['GET', 'POST'])
  211. def create():
  212. try:
  213. tool_calls = [{
  214. "function": {
  215. "name": "bucket_tool"
  216. },
  217. "type": "function"
  218. }]
  219. web_search = request.args.get("web_search")
  220. if web_search:
  221. tool_calls.append({
  222. "function": {
  223. "name": "search_tool",
  224. "arguments": {"query": web_search, "instructions": "", "max_words": 1000} if web_search != "true" else {}
  225. },
  226. "type": "function"
  227. })
  228. do_filter_markdown = request.args.get("filter_markdown")
  229. cache_id = request.args.get('cache')
  230. parameters = {
  231. "model": request.args.get("model"),
  232. "messages": [{"role": "user", "content": request.args.get("prompt")}],
  233. "provider": request.args.get("provider", None),
  234. "stream": not do_filter_markdown and not cache_id,
  235. "ignore_stream": not request.args.get("stream"),
  236. "tool_calls": tool_calls,
  237. }
  238. if cache_id:
  239. cache_id = sha256(cache_id.encode() + json.dumps(parameters, sort_keys=True).encode()).hexdigest()
  240. cache_dir = Path(get_cookies_dir()) / ".scrape_cache" / "create"
  241. cache_file = cache_dir / f"{quote_plus(request.args.get('prompt').strip()[:20])}.{cache_id}.txt"
  242. if cache_file.exists():
  243. with cache_file.open("r") as f:
  244. response = f.read()
  245. else:
  246. response = iter_run_tools(ChatCompletion.create, **parameters)
  247. cache_dir.mkdir(parents=True, exist_ok=True)
  248. with cache_file.open("w") as f:
  249. for chunk in response:
  250. f.write(str(chunk))
  251. else:
  252. response = iter_run_tools(ChatCompletion.create, **parameters)
  253. if do_filter_markdown:
  254. return Response(filter_markdown("".join([str(chunk) for chunk in response]), do_filter_markdown), mimetype='text/plain')
  255. def cast_str():
  256. for chunk in response:
  257. if not isinstance(chunk, Exception):
  258. yield str(chunk)
  259. return Response(cast_str(), mimetype='text/plain')
  260. except Exception as e:
  261. logger.exception(e)
  262. return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 500
  263. @app.route('/backend-api/v2/buckets', methods=['GET'])
  264. def list_buckets():
  265. try:
  266. buckets = get_buckets()
  267. if buckets is None:
  268. return jsonify({"error": {"message": "Error accessing bucket directory"}}), 500
  269. sanitized_buckets = [secure_filename(b) for b in buckets]
  270. return jsonify(sanitized_buckets), 200
  271. except Exception as e:
  272. return jsonify({"error": {"message": str(e)}}), 500
  273. @app.route('/backend-api/v2/files/<bucket_id>', methods=['GET', 'DELETE'])
  274. def manage_files(bucket_id: str):
  275. bucket_id = secure_filename(bucket_id)
  276. bucket_dir = get_bucket_dir(bucket_id)
  277. if not os.path.isdir(bucket_dir):
  278. return jsonify({"error": {"message": "Bucket directory not found"}}), 404
  279. if request.method == 'DELETE':
  280. try:
  281. shutil.rmtree(bucket_dir)
  282. return jsonify({"message": "Bucket deleted successfully"}), 200
  283. except OSError as e:
  284. return jsonify({"error": {"message": f"Error deleting bucket: {str(e)}"}}), 500
  285. except Exception as e:
  286. return jsonify({"error": {"message": str(e)}}), 500
  287. delete_files = request.args.get('delete_files', True)
  288. refine_chunks_with_spacy = request.args.get('refine_chunks_with_spacy', False)
  289. event_stream = 'text/event-stream' in request.headers.get('Accept', '')
  290. mimetype = "text/event-stream" if event_stream else "text/plain";
  291. return Response(get_streaming(bucket_dir, delete_files, refine_chunks_with_spacy, event_stream), mimetype=mimetype)
  292. @self.app.route('/backend-api/v2/files/<bucket_id>', methods=['POST'])
  293. def upload_files(bucket_id: str):
  294. bucket_id = secure_filename(bucket_id)
  295. bucket_dir = get_bucket_dir(bucket_id)
  296. os.makedirs(bucket_dir, exist_ok=True)
  297. filenames = []
  298. for file in request.files.getlist('files[]'):
  299. try:
  300. filename = secure_filename(file.filename)
  301. if supports_filename(filename):
  302. with open(os.path.join(bucket_dir, filename), 'wb') as f:
  303. shutil.copyfileobj(file.stream, f)
  304. filenames.append(filename)
  305. finally:
  306. file.stream.close()
  307. with open(os.path.join(bucket_dir, "files.txt"), 'w') as f:
  308. [f.write(f"{filename}\n") for filename in filenames]
  309. return {"bucket_id": bucket_id, "files": filenames}
  310. @app.route('/backend-api/v2/files/<bucket_id>/<filename>', methods=['PUT'])
  311. def upload_file(bucket_id, filename):
  312. bucket_id = secure_filename(bucket_id)
  313. bucket_dir = get_bucket_dir(bucket_id)
  314. filename = secure_filename(filename)
  315. bucket_path = Path(bucket_dir)
  316. if not supports_filename(filename):
  317. return jsonify({"error": {"message": f"File type not allowed"}}), 400
  318. if not bucket_path.exists():
  319. bucket_path.mkdir(parents=True, exist_ok=True)
  320. try:
  321. file_path = bucket_path / filename
  322. file_data = request.get_data()
  323. if not file_data:
  324. return jsonify({"error": {"message": "No file data received"}}), 400
  325. with file_path.open('wb') as f:
  326. f.write(file_data)
  327. return jsonify({"message": f"File '{filename}' uploaded successfully to bucket '{bucket_id}'"}), 201
  328. except Exception as e:
  329. return jsonify({"error": {"message": f"Error uploading file: {str(e)}"}}), 500
  330. @app.route('/backend-api/v2/upload_cookies', methods=['POST'])
  331. def upload_cookies(self):
  332. file = None
  333. if "file" in request.files:
  334. file = request.files['file']
  335. if file.filename == '':
  336. return 'No selected file', 400
  337. if file and file.filename.endswith(".json") or file.filename.endswith(".har"):
  338. filename = secure_filename(file.filename)
  339. file.save(os.path.join(get_cookies_dir(), filename))
  340. return "File saved", 200
  341. return 'Not supported file', 400
  342. def handle_synthesize(self, provider: str):
  343. try:
  344. provider_handler = convert_to_provider(provider)
  345. except ProviderNotFoundError:
  346. return "Provider not found", 404
  347. if not hasattr(provider_handler, "synthesize"):
  348. return "Provider doesn't support synthesize", 500
  349. response_data = provider_handler.synthesize({**request.args})
  350. if asyncio.iscoroutinefunction(provider_handler.synthesize):
  351. response_data = asyncio.run(response_data)
  352. else:
  353. if hasattr(response_data, "__aiter__"):
  354. response_data = to_sync_generator(response_data)
  355. response_data = safe_iter_generator(response_data)
  356. content_type = getattr(provider_handler, "synthesize_content_type", "application/octet-stream")
  357. response = flask.Response(response_data, content_type=content_type)
  358. response.headers['Cache-Control'] = "max-age=604800"
  359. return response
  360. def get_provider_models(self, provider: str):
  361. api_key = request.headers.get("x_api_key")
  362. api_base = request.headers.get("x_api_base")
  363. models = super().get_provider_models(provider, api_key, api_base)
  364. if models is None:
  365. return "Provider not found", 404
  366. return models
  367. def _format_json(self, response_type: str, content = None, **kwargs) -> str:
  368. """
  369. Formats and returns a JSON response.
  370. Args:
  371. response_type (str): The type of the response.
  372. content: The content to be included in the response.
  373. Returns:
  374. str: A JSON formatted string.
  375. """
  376. return json.dumps(super()._format_json(response_type, content, **kwargs)) + "\n"