Cloudflare.py
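# Provider module for the Cloudflare AI playground. Judging by the relative
# imports below, it belongs to a g4f-style provider framework: it harvests
# browser-like headers and cookies (via nodriver when available) and calls the
# playground's streaming inference endpoint.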

from __future__ import annotations

import asyncio
import json

from ..typing import AsyncResult, Messages, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin, get_running_loop
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
from ..requests import DEFAULT_HEADERS, has_nodriver, has_curl_cffi
from ..providers.response import FinishReason, Usage
from ..errors import ResponseStatusError, ModelNotFoundError
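
# has_nodriver / has_curl_cffi are capability flags from the requests layer:
# nodriver drives a real browser to obtain valid headers and cookies, while
# curl_cffi backs the plain-HTTP fallback path used below.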

class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
    label = "Cloudflare AI"
    url = "https://playground.ai.cloudflare.com"
    working = True
    use_nodriver = True
    api_endpoint = "https://playground.ai.cloudflare.com/api/inference"
    models_url = "https://playground.ai.cloudflare.com/api/models"
    supports_stream = True
    supports_system_message = True
    supports_message_history = True
    default_model = "@cf/meta/llama-3.3-70b-instruct-fp8-fast"
    # Python keeps only the last value for a duplicated dict key, so the
    # shadowed duplicate entries in the original mapping were dead code;
    # only the final variant per alias listed here was ever effective.
    model_aliases = {
        "llama-2-7b": "@cf/meta/llama-2-7b-chat-int8",
        "llama-3-8b": "@hf/meta-llama/meta-llama-3-8b-instruct",
        "llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-fp8",
        "llama-3.2-1b": "@cf/meta/llama-3.2-1b-instruct",
        "qwen-1.5-7b": "@cf/qwen/qwen1.5-7b-chat-awq",
    }
    # Cached request arguments (headers/cookies) shared across calls.
    _args: dict = None

    @classmethod
    def get_models(cls) -> list[str]:
        if not cls.models:
            if cls._args is None:
                if has_nodriver:
                    # Harvest headers and cookies from a real browser session.
                    get_running_loop(check_nested=True)
                    args = get_args_from_nodriver(cls.url)
                    cls._args = asyncio.run(args)
                elif not has_curl_cffi:
                    return cls.models
                else:
                    cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
            with Session(**cls._args) as session:
                response = session.get(cls.models_url)
                cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
                try:
                    raise_for_status(response)
                except ResponseStatusError:
                    return cls.models
                json_data = response.json()
                cls.models = [model.get("name") for model in json_data.get("models")]
        return cls.models
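
    # create_async_generator streams a chat completion. Request arguments are
    # loaded from the auth cache file when present, otherwise harvested via
    # nodriver; on an HTTP error the cached arguments are discarded so the
    # next call re-authenticates from scratch.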
    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        max_tokens: int = 2048,
        cookies: Cookies = None,
        timeout: int = 300,
        **kwargs
    ) -> AsyncResult:
        cache_file = cls.get_cache_file()
        if cls._args is None:
            if cache_file.exists():
                with cache_file.open("r") as f:
                    cls._args = json.load(f)
            elif has_nodriver:
                cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
            else:
                cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
        try:
            model = cls.get_model(model)
        except ModelNotFoundError:
            # Pass unknown model names through to the API unchanged.
            pass
        data = {
            "messages": [{
                **message,
                # String content is mirrored into a "parts" list; for non-string
                # content the whole message object is passed through instead.
                "content": message["content"] if isinstance(message["content"], str) else "",
                "parts": [{"type": "text", "text": message["content"]}]
                    if isinstance(message["content"], str) else message,
            } for message in messages],
            "lora": None,
            "model": model,
            "max_tokens": max_tokens,
            "stream": True,
            "system_message": "You are a helpful assistant",
            "tools": []
        }
        async with StreamSession(**cls._args) as session:
            async with session.post(
                cls.api_endpoint,
                json=data,
            ) as response:
                cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
                try:
                    await raise_for_status(response)
                except ResponseStatusError:
                    # Drop the cached arguments so the next call re-authenticates.
                    cls._args = None
                    if cache_file.exists():
                        cache_file.unlink()
                    raise
                # The endpoint streams prefixed lines: "0:" carries a JSON-encoded
                # text chunk, "e:" the final event with usage and finish reason.
                async for line in response.iter_lines():
                    if line.startswith(b'0:'):
                        yield json.loads(line[2:])
                    elif line.startswith(b'e:'):
                        finish = json.loads(line[2:])
                        yield Usage(**finish.get("usage"))
                        yield FinishReason(finish.get("finishReason"))
        # Persist the working headers/cookies for future runs.
        with cache_file.open("w") as f:
            json.dump(cls._args, f)
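

# A minimal usage sketch, not part of the original module. It assumes the file
# lives in a g4f-style package (so the relative imports resolve, e.g. run as
# `python -m g4f.Provider.Cloudflare`) and that nodriver or curl_cffi is
# installed for the initial handshake.
if __name__ == "__main__":
    async def _demo() -> None:
        async for chunk in Cloudflare.create_async_generator(
            model="llama-3.1-8b",
            messages=[{"role": "user", "content": "Say hello in one sentence."}],
        ):
            # Text chunks arrive as plain strings; Usage and FinishReason
            # objects are yielded at the end of the stream and skipped here.
            if isinstance(chunk, str):
                print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())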