Cloudflare.py

from __future__ import annotations

import asyncio
import json

from ..typing import AsyncResult, Messages, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin, get_running_loop
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
from ..requests import DEFAULT_HEADERS, has_nodriver, has_curl_cffi
from ..providers.response import FinishReason, Usage
from ..errors import ResponseStatusError, ModelNotFoundError
from .. import debug
from .helper import render_messages

class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
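    """Provider for the Cloudflare AI playground (playground.ai.cloudflare.com).

    Streams chat completions from the playground's inference endpoint and
    caches the request arguments (headers and cookies, optionally harvested
    via nodriver) in an auth file between runs.
    """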
    label = "Cloudflare AI"
    url = "https://playground.ai.cloudflare.com"
    working = has_curl_cffi
    use_nodriver = True
    api_endpoint = "https://playground.ai.cloudflare.com/api/inference"
    models_url = "https://playground.ai.cloudflare.com/api/models"
    supports_stream = True
    supports_system_message = True
    supports_message_history = True
    default_model = "@cf/meta/llama-3.3-70b-instruct-fp8-fast"
    model_aliases = {
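        # NOTE: several aliases appear more than once; Python keeps only the
        # last occurrence of a duplicate dict key, so e.g. "llama-2-7b"
        # resolves to the int8 variant below.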
  24. "llama-2-7b": "@cf/meta/llama-2-7b-chat-fp16",
  25. "llama-2-7b": "@cf/meta/llama-2-7b-chat-int8",
  26. "llama-3-8b": "@cf/meta/llama-3-8b-instruct",
  27. "llama-3-8b": "@cf/meta/llama-3-8b-instruct-awq",
  28. "llama-3-8b": "@hf/meta-llama/meta-llama-3-8b-instruct",
  29. "llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-awq",
  30. "llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-fp8",
  31. "llama-3.2-1b": "@cf/meta/llama-3.2-1b-instruct",
  32. "llama-4-scout": "@cf/meta/llama-4-scout-17b-16e-instruct",
  33. "deepseek-math-7b": "@cf/deepseek-ai/deepseek-math-7b-instruct",
  34. "deepseek-r1-qwen-32b": "@cf/deepseek-ai/deepseek-r1-distill-qwen-32b",
  35. "falcon-7b": "@cf/tiiuae/falcon-7b-instruct",
  36. "qwen-1.5-7b": "@cf/qwen/qwen1.5-7b-chat-awq",
  37. "qwen-2.5-coder": "@cf/qwen/qwen2.5-coder-32b-instruct",
  38. }
    fallback_models = list(model_aliases.keys())
    _args: dict | None = None

    @classmethod
    def get_models(cls) -> list[str]:
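        """Return the model list, fetching it from the playground once and
        falling back to the static aliases when the endpoint is unreachable."""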
        def read_models():
            with Session(**cls._args) as session:
                response = session.get(cls.models_url)
                cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
                raise_for_status(response)
                json_data = response.json()
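                # Derive short alias names by stripping quantization/variant
                # suffixes from the full model path, e.g.
                # "@cf/meta/llama-3.1-8b-instruct-awq" -> "llama-3.1-8b".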
                def clean_name(name: str) -> str:
                    return name.split("/")[-1].replace(
                        "-instruct", "").replace(
                        "-17b-16e", "").replace(
                        "-chat", "").replace(
                        "-fp8", "").replace(
                        "-fast", "").replace(
                        "-int8", "").replace(
                        "-awq", "").replace(
                        "-qvq", "").replace(
                        "-r1", "").replace(
                        "meta-llama-", "llama-")

                model_map = {clean_name(model.get("name")): model.get("name") for model in json_data.get("models")}
                cls.models = list(model_map.keys())
                cls.model_aliases = {**cls.model_aliases, **model_map}

        if not cls.models:
            try:
                cache_file = cls.get_cache_file()
                if cls._args is None:
                    if cache_file.exists():
                        with cache_file.open("r") as f:
                            cls._args = json.load(f)
                    if cls._args is None:
                        cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
                read_models()
            except ResponseStatusError as err:
                if has_nodriver:
                    async def nodriver_read_models():
                        try:
                            cls._args = await get_args_from_nodriver(cls.url)
                            read_models()
                        except Exception as e:
                            debug.log(f"Nodriver is not available: {type(e).__name__}: {e}")
                            cls.models = cls.fallback_models
                    get_running_loop(check_nested=True)
                    try:
                        asyncio.run(nodriver_read_models())
                    except RuntimeError:
                        debug.log("Nodriver is not available: RuntimeError")
                        cls.models = cls.fallback_models
                else:
                    cls.models = cls.fallback_models
                    debug.log(f"Nodriver is not installed: {type(err).__name__}: {err}")
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        max_tokens: int = 2048,
        **kwargs
    ) -> AsyncResult:
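        """Stream a completion from the inference endpoint, yielding text
        chunks followed by Usage and FinishReason objects."""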
        cache_file = cls.get_cache_file()
        if cls._args is None:
            if cache_file.exists():
                with cache_file.open("r") as f:
                    cls._args = json.load(f)
            elif has_nodriver:
                try:
                    cls._args = await get_args_from_nodriver(cls.url, proxy=proxy)
                except (RuntimeError, FileNotFoundError) as e:
                    debug.log(f"Nodriver is not available: {type(e).__name__}: {e}")
                    cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}, "impersonate": "chrome"}
            else:
                cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}, "impersonate": "chrome"}
        try:
            model = cls.get_model(model)
        except ModelNotFoundError:
            # Unknown model names are passed through to the API unchanged.
            pass
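        # Each message carries a "parts" list alongside its plain "content"
        # field, matching the shape this endpoint expects.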
        data = {
            "messages": [
                {**message, "parts": [{"type": "text", "text": message["content"]}]}
                for message in render_messages(messages)
            ],
            "lora": None,
            "model": model,
            "max_tokens": max_tokens,
            "stream": True,
            "system_message": "You are a helpful assistant",
            "tools": []
        }
        async with StreamSession(**cls._args) as session:
            async with session.post(
                cls.api_endpoint,
                json=data,
            ) as response:
                cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
                try:
                    await raise_for_status(response)
                except ResponseStatusError:
                    # Drop the cached arguments so the next call re-authenticates.
                    cls._args = None
                    if cache_file.exists():
                        cache_file.unlink()
                    raise
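                # Each streamed line is prefixed with a type tag, as in the
                # Vercel AI SDK data-stream format: "0:" carries a JSON-encoded
                # text delta, "e:" carries usage and finish metadata.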
                async for line in response.iter_lines():
                    if line.startswith(b'0:'):
                        yield json.loads(line[2:])
                    elif line.startswith(b'e:'):
                        finish = json.loads(line[2:])
                        yield Usage(**finish.get("usage"))
                        yield FinishReason(finish.get("finishReason"))

        with cache_file.open("w") as f:
            json.dump(cls._args, f)
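
# Minimal usage sketch (assumes this module sits in its usual place inside
# the g4f package so the relative imports above resolve):
#
#     import asyncio
#     from g4f.Provider import Cloudflare
#
#     async def main():
#         async for chunk in Cloudflare.create_async_generator(
#             model="llama-3.1-8b",
#             messages=[{"role": "user", "content": "Hello!"}],
#         ):
#             if isinstance(chunk, str):  # skip Usage/FinishReason objects
#                 print(chunk, end="", flush=True)
#
#     asyncio.run(main())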