Cloudflare.py

from __future__ import annotations

import asyncio
import json
from pathlib import Path

from ..typing import AsyncResult, Messages, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
from ..requests import DEFAULT_HEADERS, has_nodriver, has_curl_cffi
from ..providers.response import FinishReason
from ..cookies import get_cookies_dir
from ..errors import ResponseStatusError, ModelNotFoundError
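
# Provider for Cloudflare's public AI playground. Model listings come from
# the playground's /api/models endpoint; completions are requested from the
# /api/inference endpoint and streamed back as server-sent events.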
class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
    label = "Cloudflare AI"
    url = "https://playground.ai.cloudflare.com"
    api_endpoint = "https://playground.ai.cloudflare.com/api/inference"
    models_url = "https://playground.ai.cloudflare.com/api/models"
    working = True
    supports_stream = True
    supports_system_message = True
    supports_message_history = True
    default_model = "@cf/meta/llama-3.3-70b-instruct-fp8-fast"
    model_aliases = {
        # Python keeps only the last value for a duplicated dict key; the
        # commented entries are shadowed alternatives for the same alias.
        # "llama-2-7b": "@cf/meta/llama-2-7b-chat-fp16",
        "llama-2-7b": "@cf/meta/llama-2-7b-chat-int8",
        # "llama-3-8b": "@cf/meta/llama-3-8b-instruct",
        # "llama-3-8b": "@cf/meta/llama-3-8b-instruct-awq",
        "llama-3-8b": "@hf/meta-llama/meta-llama-3-8b-instruct",
        # "llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-awq",
        "llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-fp8",
        "llama-3.2-1b": "@cf/meta/llama-3.2-1b-instruct",
        "qwen-1.5-7b": "@cf/qwen/qwen1.5-7b-chat-awq",
    }
    # Cached request arguments (headers and cookies), shared across calls.
    _args: dict = None

    @classmethod
    def get_cache_file(cls) -> Path:
        # The cache file persists harvested headers/cookies between runs.
        return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"

    @classmethod
    def get_models(cls) -> list[str]:
        if not cls.models:
            if cls._args is None:
                if has_nodriver:
                    # Harvest headers and cookies from a real browser session.
                    get_running_loop(check_nested=True)
                    args = get_args_from_nodriver(cls.url)
                    cls._args = asyncio.run(args)
                elif not has_curl_cffi:
                    return cls.models
                else:
                    cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
            with Session(**cls._args) as session:
                response = session.get(cls.models_url)
                cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
                try:
                    raise_for_status(response)
                except ResponseStatusError:
                    return cls.models
                json_data = response.json()
                cls.models = [model.get("name") for model in json_data.get("models")]
        return cls.models
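
    # Usage sketch for get_models (an illustration only; assumes this module
    # is importable as part of the installed g4f package, e.g. via
    # "from g4f.Provider import Cloudflare", so the relative imports resolve):
    #
    #     print(Cloudflare.get_models())
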
    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        max_tokens: int = 2048,
        cookies: Cookies = None,
        timeout: int = 300,
        **kwargs
    ) -> AsyncResult:
        cache_file = cls.get_cache_file()
        if cls._args is None:
            if cache_file.exists():
                with cache_file.open("r") as f:
                    cls._args = json.load(f)
            # Only fall back to a live browser session when no cache exists.
            elif has_nodriver:
                cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
            else:
                cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
        try:
            model = cls.get_model(model)
        except ModelNotFoundError:
            # Pass unknown model names through to the API unchanged.
            pass
        data = {
            "messages": messages,
            "lora": None,
            "model": model,
            "max_tokens": max_tokens,
            "stream": True
        }
        async with StreamSession(**cls._args) as session:
            async with session.post(
                cls.api_endpoint,
                json=data,
            ) as response:
                cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
                try:
                    await raise_for_status(response)
                except ResponseStatusError:
                    # Drop the stale session arguments and cache, then re-raise.
                    cls._args = None
                    if cache_file.exists():
                        cache_file.unlink()
                    raise
                reason = None
                # Parse the server-sent event stream; payload lines are
                # prefixed with b'data: ' and terminated by b'data: [DONE]'.
                async for line in response.iter_lines():
                    if line.startswith(b'data: '):
                        if line == b'data: [DONE]':
                            break
                        try:
                            content = json.loads(line[6:].decode())
                            if content.get("response") and content.get("response") != '</s>':
                                yield content['response']
                                reason = "max_tokens"
                            elif content.get("response") == '':
                                # An empty chunk marks a natural stop.
                                reason = "stop"
                        except Exception:
                            continue
                if reason is not None:
                    yield FinishReason(reason)
                # Persist the (possibly refreshed) session arguments for reuse.
                with cache_file.open("w") as f:
                    json.dump(cls._args, f)
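
# Usage sketch (an illustration, not part of the provider; assumes the
# surrounding g4f package is installed so "from g4f.Provider import
# Cloudflare" resolves). The generator yields text chunks followed by a
# FinishReason marker, so non-string chunks are filtered out here:
#
#     import asyncio
#     from g4f.Provider import Cloudflare
#
#     async def main():
#         async for chunk in Cloudflare.create_async_generator(
#             model="llama-3.1-8b",
#             messages=[{"role": "user", "content": "Hello"}],
#         ):
#             if isinstance(chunk, str):
#                 print(chunk, end="")
#
#     asyncio.run(main())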